#!/net/data/search-fairness/.venv/bin/python

from functools import partial
import re

from sentence_transformers import SentenceTransformer
import numpy as np
import pandas as pd
from scipy.spatial.distance import cosine
import statsmodels.stats.api as sm
import fire
import matplotlib.pyplot as plt
import seaborn as sns

# On-disk embedding cache: ``vectorize`` writes the distinct query strings and
# their embeddings (in the same row order); ``load_vecs`` reads them back.
_LABELS_DIR = "/net/data/search-fairness/data/labels"
FP_QRYS = f"{_LABELS_DIR}/qrys.npy"
FP_VECS = f"{_LABELS_DIR}/embeddings.npy"


def load_qrys(
    fp_qrys: str = "/net/data/search-fairness/data/image_search_browser_history.csv",
    fp_weights: str = "/home/rer/proj/webusage/data/survey/DART0034_OUTPUT.DTA",
    fp_labels: str = "/net/data/search-fairness/data/labels/labels_final.csv",
    ref_year: int = 2020,
):
    """Load image-search queries joined with survey demographics and manual labels.

    Args:
        fp_qrys: CSV with columns ``qry_id``, ``qry``, ``url``, ``user_id``,
            ``timestamp`` and ``domain`` (other columns are ignored).
        fp_weights: Stata survey file; only respondents whose ``samplegroup``
            is "CCES 2018 recontact" are kept.
        fp_labels: CSV of manual labels keyed by ``qry_id``; its ``qry``
            column is dropped so it does not clash with the query text.
        ref_year: year from which ``birthyr`` is subtracted to derive ``age``.

    Returns:
        Rows with a non-missing ``qry`` whose ``user_id`` matches a survey
        ``session_visa`` (inner join), plus the selected survey columns, the
        derived ``age``, and the label columns (left join, so NaN for
        unlabelled queries).

    Raises:
        pandas.errors.MergeError: if ``session_visa`` or the labels'
            ``qry_id`` is not unique (both joins are validated as m:1).
    """
    # image search query data
    qrys = pd.read_csv(
        fp_qrys,
        usecols=["qry_id", "qry", "url", "user_id", "timestamp", "domain"],
        parse_dates=["timestamp"],
    )
    qrys = qrys.loc[qrys.qry.notna(), :]

    # survey data: restrict to the recontact sample and the columns we use.
    survey = pd.read_stata(fp_weights)
    survey = survey.loc[
        survey.samplegroup == "CCES 2018 recontact",
        [
            "session_visa",
            "weight",
            "birthyr",
            "pid3",
            "educ",
            "faminc_new",
            "gender",
            "race",
        ],
    ].copy()  # explicit copy: the column assignments below act on an independent frame
    survey["age"] = ref_year - survey.birthyr
    # Collapse these race values into "Other".
    # NOTE(review): read_stata may return `race` as a Categorical; confirm that
    # "Other" is already one of its categories.
    survey["race"] = survey.race.replace(
        {"Middle Eastern": "Other", "Native American": "Other"}
    )

    qrys = pd.merge(
        qrys, survey, left_on="user_id", right_on="session_visa", validate="m:1"
    )

    # merge manual labels of people queries
    labels = pd.read_csv(fp_labels).drop(columns="qry")
    return pd.merge(qrys, labels, on="qry_id", validate="m:1", how="left")


def vectorize(model_name: str = "all-MiniLM-L6-v2"):
    """Embed each distinct query string and cache the result on disk.

    The distinct queries go to ``FP_QRYS`` and their embeddings go to
    ``FP_VECS``, in the same row order, so ``load_vecs`` can zip them back
    together.

    Args:
        model_name: identifier passed straight to ``SentenceTransformer``.
    """
    distinct_qrys = load_qrys().qry.unique()
    encoder = SentenceTransformer(model_name)
    vectors = encoder.encode(distinct_qrys)
    # Save queries first, then vectors (same order as before).
    for path, array in ((FP_QRYS, distinct_qrys), (FP_VECS, vectors)):
        np.save(path, array)


def prep_qrys(qrys):
    """Order queries in time and collapse each user's runs of repeated queries.

    Consecutive identical queries from the same user (interactions with image
    results) are collapsed to the *last* row of each run.

    Uses a vectorized per-user ``shift`` instead of ``groupby().apply``. Apply
    over the grouping column is deprecated in pandas 2.2 and would drop
    ``user_id`` in pandas 3.

    Args:
        qrys: frame with at least ``user_id``, ``qry`` and ``timestamp``.

    Returns:
        A new frame grouped by ``user_id`` (users in sorted order), time-ordered
        within each user, with a fresh RangeIndex. Rows with a missing
        ``user_id`` are dropped, as ``groupby`` does.
    """
    qrys = qrys.sort_values(by="timestamp")
    qrys = qrys[qrys.user_id.notna()]
    # The next query by the same user. The last query of each user gets NaN,
    # which compares unequal, so that row is kept.
    next_qry = qrys.groupby("user_id").qry.shift(-1)
    qrys = qrys[next_qry != qrys.qry]
    # A stable sort groups users together while keeping timestamp order
    # within each user.
    qrys = qrys.sort_values(by="user_id", kind="stable")
    return qrys.reset_index(drop=True)


def load_vecs(fp_qrys=None, fp_vecs=None):
    """Load the cached query embeddings as a ``{query: vector}`` dict.

    Args:
        fp_qrys: ``.npy`` file of query strings; defaults to ``FP_QRYS``.
        fp_vecs: ``.npy`` file of embeddings, row-aligned with ``fp_qrys``;
            defaults to ``FP_VECS``.

    Returns:
        Dict mapping each query string to its embedding row.

    Raises:
        ValueError: if the two arrays differ in length (a stale or partial
            cache). ``zip`` would otherwise silently truncate.
    """
    fp_qrys = FP_QRYS if fp_qrys is None else fp_qrys
    fp_vecs = FP_VECS if fp_vecs is None else fp_vecs
    # Query strings are typically saved as an object array, which np.load only
    # reads with allow_pickle=True. Only point this at locally produced files.
    qrys = np.load(fp_qrys, allow_pickle=True)
    vecs = np.load(fp_vecs, allow_pickle=True)
    if len(qrys) != len(vecs):
        raise ValueError(
            f"query cache ({len(qrys)} rows) and embedding cache "
            f"({len(vecs)} rows) are out of sync; re-run vectorize"
        )
    return dict(zip(qrys, vecs))


def add_sessions(df, vec_dict):
    """Annotate one user's time-ordered queries with gap and similarity features.

    Returns a new frame instead of mutating ``df``. This function is used via
    ``GroupBy.apply``, where mutating the passed group is unsupported.

    Args:
        df: queries of a single user, expected to be sorted by ``timestamp``;
            needs ``timestamp`` (datetime) and ``qry`` columns.
        vec_dict: mapping from query string to its embedding vector.

    Returns:
        A copy of ``df`` with two extra columns:
        ``time_diff``: seconds since the previous row (NaN for the first row).
        ``cos_sim``: cosine similarity to the previous row's query embedding
        (0 for the first row).

    Raises:
        KeyError: if a query has no entry in ``vec_dict``.
    """
    time_diff = df.timestamp.diff().dt.total_seconds()
    vecs = [vec_dict[qry] for qry in df.qry]
    sims = [1 - cosine(prev, cur) for prev, cur in zip(vecs, vecs[1:])]
    # The first query has no predecessor. A similarity of 0 makes it start a
    # new similarity-based session for any positive threshold.
    cos_sim = ([0] + sims) if len(df) else []
    return df.assign(time_diff=time_diff, cos_sim=cos_sim)


def get_sessions_time(df, thresh=30):
    """Number time-gap sessions within one user's query stream.

    A new session starts whenever ``time_diff`` is at least ``thresh``
    minutes. Returns a new frame rather than mutating ``df``, which matters
    because this is called through ``GroupBy.apply``.

    Args:
        df: one user's queries with a ``time_diff`` column in seconds.
        thresh: gap in minutes that opens a new session.

    Returns:
        A copy of ``df`` with ``session_time_{thresh}``, a running session
        counter. A NaN gap (the first row) compares False, so counting
        starts at 0.
    """
    new_session = df["time_diff"] >= thresh * 60
    return df.assign(**{f"session_time_{thresh}": new_session.cumsum()})


def get_sessions_sim(df, thresh=0.7):
    """Number similarity-based sessions within one user's query stream.

    A new session starts whenever ``cos_sim`` (similarity to the previous
    query) is below ``thresh``; a NaN similarity compares False and does not
    start one. Returns a new frame rather than mutating ``df``, which matters
    because this is called through ``GroupBy.apply``.

    Args:
        df: one user's queries with a ``cos_sim`` column.
        thresh: similarity below which a query opens a new session.

    Returns:
        A copy of ``df`` with ``session_sim_{thresh}``, a running session
        counter.
    """
    new_session = df["cos_sim"] < thresh
    return df.assign(**{f"session_sim_{thresh}": new_session.cumsum()})


def add_session_idx(df):
    """Return a copy of ``df`` with ``session_idx`` = 0, 1, 2, ... in row order.

    Meant to be applied to a single session's rows, so 0 marks the first
    query of the session. It does not mutate ``df``, because pandas does not
    support mutating the group passed to ``GroupBy.apply``.
    """
    return df.assign(session_idx=range(len(df)))


def all_sessions(data, sesh_func, thresh_list):
    """Segment every user's query stream into sessions, once per threshold.

    Args:
        data: per-query frame with a ``user_id`` column plus whatever
            ``sesh_func`` reads (``time_diff`` or ``cos_sim``).
        sesh_func: per-user segmentation function called as
            ``sesh_func(df, thresh=...)``; ``get_sessions_time`` and
            ``get_sessions_sim`` are recognised.
        thresh_list: thresholds to try.

    Returns:
        One frame per threshold. For the two recognised functions, each frame
        has ``session_idx``: the query's position within its
        (``user_id``, session) group, where 0 is the session's first query.
        Other functions get no ``session_idx``.
    """
    sessions = []
    for thresh in thresh_list:
        # Select all columns explicitly so `user_id` is still passed to
        # sesh_func. pandas >= 2.2 deprecates, and 3.0 removes, grouping
        # columns from apply.
        sesh = (
            data.groupby("user_id", group_keys=False)[list(data.columns)]
            .apply(partial(sesh_func, thresh=thresh))
            .reset_index(drop=True)
        )
        if sesh_func is get_sessions_time:
            s_col = f"session_time_{thresh}"
        elif sesh_func is get_sessions_sim:
            s_col = f"session_sim_{thresh}"
        else:
            s_col = None
        if s_col is not None:
            # cumcount gives the same values as applying add_session_idx to each
            # session (a single-query session gets 0). Unlike the old
            # short/long split, it never duplicates rows when `url` is NaN
            # and does not need a `url` column.
            sesh["session_idx"] = sesh.groupby(["user_id", s_col]).cumcount()
        sessions.append(sesh)
    return sessions


def compare(data_list, thresh_list):
    """Compute, per threshold, a CI for the refined-minus-initial difference.

    Only rows with a non-missing ``category_nlp`` are considered. "Initial"
    queries have ``session_idx == 0`` and "refined" queries have
    ``session_idx > 0``. The proportion compared is the share of rows with a
    non-missing ``adjective_race_gender``.

    Args:
        data_list: session-annotated frames, one per threshold.
        thresh_list: thresholds aligned with ``data_list``.

    Returns:
        Frame with columns ``lo``, ``hi`` (CI bounds from
        ``statsmodels`` ``confint_proportions_2indep``, refined first) and
        ``thresh``.
    """

    def split(frame):
        labelled = frame.category_nlp.notna()
        initial = frame[labelled & (frame.session_idx == 0)]
        refined = frame[labelled & (frame.session_idx > 0)]
        return initial, refined

    splits = [split(frame) for frame in data_list]

    records = []
    for (initial, refined), thresh in zip(splits, thresh_list):
        refined_hits = refined.adjective_race_gender.notna().sum()
        initial_hits = initial.adjective_race_gender.notna().sum()
        lo, hi = sm.confint_proportions_2indep(
            refined_hits, len(refined), initial_hits, len(initial)
        )
        records.append({"lo": lo, "hi": hi, "thresh": thresh})

    return pd.DataFrame(records)


def plot(
    plot_data, min_thresh=0.1, fp_out="/net/data/search-fairness/plots/adjectives.pdf"
):
    """Plot the per-threshold confidence intervals from ``compare`` and save them.

    Args:
        plot_data: frame with ``lo``, ``hi`` and ``thresh`` columns.
        min_thresh: thresholds below this value are omitted.
        fp_out: output path passed to ``plt.savefig``.
    """
    long_data = plot_data.melt(id_vars="thresh")
    long_data = long_data[long_data.thresh >= min_thresh].copy()
    # Round only the threshold, which becomes the x tick label. Rounding the
    # whole frame would also quantise the CI bounds to 0.01.
    long_data["thresh"] = long_data.thresh.round(2)

    fig = plt.figure(figsize=(5, 3))
    ax = plt.gca()
    ax.tick_params(which="both", bottom=False, left=False, right=False)
    plt.tight_layout()
    plt.grid(linestyle=":")
    # Each threshold contributes two values (lo and hi). pointplot draws their
    # mean with an error bar aggregated from those two values.
    # NOTE(review): presumably intended to render the interval [lo, hi];
    # confirm against the seaborn version in use.
    g = sns.pointplot(data=long_data, x="thresh", y="value", linestyles="none")
    g.set(
        xlabel="Cosine Similarity Threshold",
        ylabel="Difference in $Pr$(demographic word)\n(refined query - initial query)",
    )
    for spine in ("top", "right", "bottom", "left"):
        ax.spines[spine].set_visible(False)
    plt.axhline(y=0, color="r", linestyle="--")
    plt.savefig(fp_out, bbox_inches="tight")
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)


def check_adj_match(row, adj_res, cat):
    """Flag whether the row's adjective text matches any pattern of a category.

    Args:
        row: object exposing ``adjective_race_gender`` (e.g. a DataFrame row).
        adj_res: mapping from category name to a list of regex patterns.
        cat: category whose patterns are tried.

    Returns:
        1 if ``re.search`` finds any of the patterns in the text, else 0.
    """
    text = row.adjective_race_gender
    patterns = adj_res[cat]
    found = any(re.search(pattern, text) is not None for pattern in patterns)
    return 1 if found else 0


def prep_r_data(data, fp_r_data="/net/data/search-fairness/data/adjectives.csv"):
    """Write per-query adjective indicators and respondent demographics to CSV.

    Only rows with a non-missing ``category_nlp`` are exported.

    Columns written:
    - ``adj_<Cat>``: 0/1, whether ``adjective_race_gender`` contains one of
      the category's words as a whitespace-delimited token. Matching is
      case-sensitive; a missing value counts as "".
    - ``is_Female`` / ``is_Male``: from ``gender``.
    - ``is_<Race>``: from ``race``, with spaces replaced by underscores.
    - ``user_id``.

    Args:
        data: per-query frame with ``category_nlp``,
            ``adjective_race_gender``, ``gender``, ``race`` and ``user_id``.
        fp_r_data: output CSV path.

    Raises:
        KeyError: if one of the exported race groups (Black, Hispanic, Asian,
            Two or more races) does not occur among the kept rows.
    """
    match_adjs = {
        "Male": ["man", "men", "male", "boy"],
        "Female": ["woman", "women", "female", "girl"],
        "Black": ["black"],
        "White": ["white"],
    }

    m = data.category_nlp.notna()
    adj_qrys = data.loc[m, :].copy()
    adjectives = adj_qrys.adjective_race_gender.fillna("")
    adj_qrys["adjective_race_gender"] = adjectives

    # One alternation per category is equivalent to "any word matches". The
    # raw string avoids the invalid "\s" escape in a normal string literal.
    for cat, words in match_adjs.items():
        alternatives = "|".join(re.escape(word) for word in words)
        pattern = rf"(?:^|\s)(?:{alternatives})(?:$|\s)"
        adj_qrys[f"adj_{cat.replace(' ', '_')}"] = adjectives.str.contains(
            pattern, regex=True
        ).astype(int)

    adj_qrys["is_Male"] = adj_qrys.gender == "Male"
    adj_qrys["is_Female"] = adj_qrys.gender == "Female"
    # Missing race values get no indicator column of their own (and are False
    # in all others) instead of crashing on NaN.replace.
    for race in adj_qrys.race.dropna().unique():
        adj_qrys[f"is_{str(race).replace(' ', '_')}"] = adj_qrys.race == race

    cols = [
        "adj_Female",
        "adj_Male",
        "adj_Black",
        "adj_White",
        "is_Female",
        "is_Male",
        "is_Black",
        "is_Hispanic",
        "is_Asian",
        "is_Two_or_more_races",
        "user_id",
    ]
    adj_qrys[cols].to_csv(fp_r_data, index=False)


def main():
    """Run the pipeline end to end: load, sessionize, export for R, plot.

    Steps:
    1. Load the queries, collapse repeats, and attach gap and similarity
       features per user.
    2. Export the adjective and demographic indicators via ``prep_r_data``.
    3. Build similarity-based sessions for thresholds 0.1 to 0.9, compare
       initial and refined queries, and plot the intervals.
    """
    qrys = load_qrys()
    qrys = prep_qrys(qrys)
    vec_dict = load_vecs()
    # Select all columns explicitly so `user_id` is still passed to, and kept
    # by, add_sessions. pandas >= 2.2 deprecates, and 3.0 removes, grouping
    # columns from apply.
    data = qrys.groupby("user_id", group_keys=False)[list(qrys.columns)].apply(
        partial(add_sessions, vec_dict=vec_dict)
    )

    prep_r_data(data)

    thresh_list = np.arange(0.1, 1, 0.1)
    qry_sim = all_sessions(data, get_sessions_sim, thresh_list)
    plot_data = compare(qry_sim, thresh_list)
    plot(plot_data)


if __name__ == "__main__":
    # python-fire with no argument builds the CLI from this module's globals,
    # so functions such as `vectorize` and `main` can be run as sub-commands.
    fire.Fire()
